{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to Deep Learning with PyTorch\n",
    "\n",
    "In this notebook, we will go over the basics of deep learning using PyTorch. PyTorch is a framework for manipulating tensors with constructs similar to NumPy. It also provides automatic differentiation, which is used to carry out back-propagation, the backbone of training a neural network.\n",
    "\n",
    "## What are Neural Networks\n",
    "\n",
    "Deep learning is based on artificial neural networks, which are made up of neurons. A neuron takes inputs, calculates their weighted sum, and then passes the sum through a non-linear function (called an activation function), as shown below:\n",
    "\n",
    "![Neuron](./images/neuron.png \"Neuron\")\n",
    "\n",
    "<br/>\n",
    "<br/>\n",
    "\n",
    "We stack these neurons to make a neural network, as shown below:\n",
    "\n",
    "![Neural network](./images/nn.svg \"Neural network\")\n",
    "\n",
    "Let us now create a neural network in PyTorch. We will train this network to take an MNIST image as input and predict the digit class it belongs to.\n"
   ]
  },
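  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before building the full network, here is a rough sketch of the single neuron described above: a weighted sum of the inputs followed by an activation function. The input values, weights, and bias are made-up numbers chosen only for illustration, and ReLU is just one possible choice of activation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Hypothetical inputs, weights, and bias for one neuron (illustrative values only)\n",
    "inputs = torch.tensor([1.0, 2.0, 3.0])\n",
    "weights = torch.tensor([0.5, -1.0, 0.25])\n",
    "bias = torch.tensor(0.1)\n",
    "\n",
    "# Weighted sum of the inputs plus the bias\n",
    "weighted_sum = torch.dot(inputs, weights) + bias\n",
    "\n",
    "# Non-linear activation: ReLU returns max(0, x)\n",
    "activation = F.relu(weighted_sum)\n",
    "print(weighted_sum.item(), activation.item())"
   ]
  },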
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MNIST\n",
    "\n",
    "The MNIST dataset contains 28x28-pixel images of handwritten digits, so each image has 784 pixels.\n",
    "\n",
    "We will have 10 units at the output layer, one for each digit (0-9) the image can belong to.\n",
    "\n",
    "Let us first load the data and plot one of the images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a transform to normalize the data\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                              transforms.Normalize(0.5, 0.5),\n",
    "                             ])\n",
    "# Download and load the training data\n",
    "trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)\n",
    "\n",
    "# Download and load the test data\n",
    "testset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let us load a batch of training data and checkout its shape\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = dataiter.next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 1, 28, 28])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images.size()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An image tensor of size `(128, 1, 28, 28)` means that the batch contains 128 images, each of size 1x28x28 (channels x height x width)."
   ]
  },
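  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fully connected network built below expects each image as a flat vector of 784 values, so a batch of shape `(N, 1, 28, 28)` has to be reshaped to `(N, 784)` before it is passed in. A minimal sketch of that step, using a random tensor as a stand-in for a real batch (the name `fake_batch` is hypothetical and not used elsewhere in this notebook):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Stand-in for a batch of MNIST images: (batch, channels, height, width)\n",
    "fake_batch = torch.randn(128, 1, 28, 28)\n",
    "\n",
    "# Keep the batch dimension and flatten the remaining dimensions into one\n",
    "flat = fake_batch.view(fake_batch.size(0), -1)\n",
    "print(flat.shape)"
   ]
  },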
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x21f3a2af7f0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAANTUlEQVR4nO3dX6wc9XnG8eepa7ggAfkU2ToQqNPIQFFFibFMJVtVqigRGCQ7FymxrIiqkY4vjIilomKlF7GoKqA07hVEOlZQXCt1sMSfWFHBMZZVGrAiDOaPiWuDwcSODzbUF8ESIgd4e3HG0cHszB7vzO6s/X4/0mp3592ZebXweGb3N2d/jggBOP/9UdsNABgMwg4kQdiBJAg7kARhB5L440HuzDZf/QN9FhHutLzWkd32TbYP2H7D9ro62wLQX+51nN32LEkHJX1N0lFJz0taGRG/rliHIzvQZ/04si+W9EZEvBkRv5f0U0nLa2wPQB/VCfvlko5Me360WPYptsds77G9p8a+ANRU5wu6TqcKnzlNj4hxSeMSp/FAm+oc2Y9KumLa8y9IOlavHQD9Uifsz0taYPuLti+Q9C1J25ppC0DTej6Nj4iPbN8habukWZIejojXGusMQKN6HnrraWd8Zgf6ri8X1QA4dxB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kMRAp2xGf6xYsaK0tnDhwsp1u/268G233dZLS39w9dVX97zvbduqpyF4+eWXK+t79+4trT3xxBOV656POLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs58DVq1aVVnfuHFjae3DDz+sXPett96qrB85cqSyfuedd1bWb7311tLaypUrK9e98sora9UnJydLaxnH2WuF3fZhSe9L+ljSRxGxqImmADSviSP730TEew1sB0Af8ZkdSKJu2EPSL2y/YHus0wtsj9neY3tPzX0BqKHuafySiDhme66kHbb/NyKemf6CiBiXNC5Jtqv/8gFA39Q6skfEseL+hKTHJS1uoikAzes57LYvsv35048lfV3SvqYaA9CsOqfx8yQ9bvv0dv4zIp5qpCt8ymWXXVZZv/DCC0trS5YsqVy36m++m7B///7S2gMPPNDXfePTeg57RLwp6S8b7AVAHzH0BiRB2IEkCDuQBGEHkiDsQBLu9nO+je6MK+g6uuGGGyrrO3bsqKxfcsklpbVZs2b11BPOXRHhTss5sgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEvyU9BC45pprKutV4+jATHFkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkGGcfAvfcc0+t9Xfv3t1QJzifcWQHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZx8CF198ca3133333Z7XXbBgQWW9ajroug4cOFBZn5yc7Nu+M+p6ZLf9sO0TtvdNWzZie4ft14v7Of1tE0BdMzmN/7Gkm85Ytk7SzohYIGln8RzAEOsa9oh4RtLJMxYvl7SpeLxJ0opm2wLQtF4/s8+LiAlJiogJ23PLXmh7TNJYj/sB0JC+f0EXEeOSxiUmdgTa1OvQ23Hbo5JU3J9oriUA/dBr2LdJur14fLuknzXTDoB+6To/u+0tkr4i6VJJxyV9X9ITkrZKulLSbyR9MyLO/BKv07Y4je+g2zj5yMhIZf3gwYOltX379pXWJOmWW26prNcdZ7c7ThUuSdqwYUPlunfddVetfWdVNj9718/sEbGypPTVWh0BGCgulwWSIOxAEoQdSIKwA0kQdiCJrkNvje6MobeOVq1aVVnfvHlzz9uemJiorC9cuLCyfvz48Z73LUlV/3+98847lesuXbq0sn7o0KGeejrflQ29cWQHkiDsQBKEHUiCsANJEHYgCcIOJEHYgST4KelzwNtvv11Zf/DBB0tr
3cbo646jd1M1zj53bumvmUmSFi9eXFlnnP3scGQHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZx8CO3furKw//fTTlfV+j5Xj/MCRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJx9CHT7/XSgCV2P7LYftn3C9r5py9bb/q3tl4rbsv62CaCumZzG/1jSTR2W/3tEXF/c/qvZtgA0rWvYI+IZSScH0AuAPqrzBd0dtl8pTvPnlL3I9pjtPbb31NgXgJp6DfsPJX1J0vWSJiT9oOyFETEeEYsiYlGP+wLQgJ7CHhHHI+LjiPhE0kZJ1T8DCqB1PYXd9ui0p9+QtK/stQCGQ9f52W1vkfQVSZdKOi7p+8Xz6yWFpMOSVkdE9UTgYn72jHbv3l1au/HGGyvXfeqppyrry5Yx4ttJ2fzsXS+qiYiVHRb/qHZHAAaKy2WBJAg7kARhB5Ig7EAShB1Igj9xRV9VDa91G/btNt00zg5HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2VLrgggsq6/fee++AOkFdHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2VHpqquuqqyvXbt2MI2gNo7sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+wNWLFiRWX97rvvrqzPnz+/sj4+Pl5Z3759e2ntueeeq1x3wYIFlfU1a9ZU1u2OswPPqL5hw4bKdbds2VJZx9npemS3fYXtXbb3237N9neL5SO2d9h+vbif0/92AfRqJqfxH0n6h4j4c0l/JWmN7WslrZO0MyIWSNpZPAcwpLqGPSImIuLF4vH7kvZLulzSckmbipdtkrSiTz0CaMBZfWa3PV/SlyX9StK8iJiQpv5BsD23ZJ0xSWM1+wRQ04zDbvtzkh6VtDYiftfti5nTImJc0nixjeqZ/AD0zYyG3mzP1lTQfxIRjxWLj9seLeqjkk70p0UATXC3aXM9dQjfJOlkRKydtvwBSf8XEffZXidpJCL+scu2zssj+7x58yrrTz75ZGX92muvrazPnj27sv7BBx+U1lavXl257v33319ZHx0drax3s2vXrtLazTffXLnu5ORkrX1nFREdT7tnchq/RNK3Jb1q+6Vi2fck3Sdpq+3vSPqNpG820CeAPuka9oj4paSyD+hfbbYdAP3C5bJAEoQdSIKwA0kQdiAJwg4k0XWcvdGdnafj7HUtX768sr5+/frK+nXXXVda63alY7f//qdOnaqsb926tbL+0EMPldb27t1buS56UzbOzpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnP0cMDIyUll/9tlnS2uPPPJIrX1v3ry5sn7o0KFa20fzGGcHkiPsQBKEHUiCsANJEHYgCcIOJEHYgSQYZwfOM4yzA8kRdiAJwg4kQdiBJAg7kARhB5Ig7EASXcNu+wrbu2zvt/2a7e8Wy9fb/q3tl4rbsv63C6BXXS+qsT0qaTQiXrT9eUkvSFoh6W8lnYqIf5vxzrioBui7sotqZjI/+4SkieLx+7b3S7q82fYA9NtZfWa3PV/SlyX9qlh0h+1XbD9se07JOmO299jeU69VAHXM+Np425+T9N+S/iUiHrM9T9J7kkLSP2vqVP/vu2yD03igz8pO42cUdtuzJf1c0vaI2NChPl/SzyPiL7psh7ADfdbzH8J4ahrQH0naPz3oxRd3p31D0r66TQLon5l8G79U0v9IelXSJ8Xi70laKel6TZ3GH5a0uvgyr2pbHNmBPqt1Gt8Uwg70H3/PDiRH2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLrD0427D1Jb097fmmxbBgNa2/D2pdEb71qsrc/LSsM9O/ZP7Nze09ELGqtgQrD2tuw9iXRW68G1Run8UAShB1Iou2wj7e8/yrD2tuw9iXRW68G0lurn9kBDE7bR3YAA0LYgSRaCbvtm2wfsP2G7XVt9FDG9mHb
rxbTULc6P10xh94J2/umLRuxvcP268V9xzn2WuptKKbxrphmvNX3ru3pzwf+md32LEkHJX1N0lFJz0taGRG/HmgjJWwflrQoIlq/AMP2X0s6Jek/Tk+tZftfJZ2MiPuKfyjnRMTdQ9Lbep3lNN596q1smvG/U4vvXZPTn/eijSP7YklvRMSbEfF7ST+VtLyFPoZeRDwj6eQZi5dL2lQ83qSp/1kGrqS3oRARExHxYvH4fUmnpxlv9b2r6Gsg2gj75ZKOTHt+VMM133tI+oXtF2yPtd1MB/NOT7NV3M9tuZ8zdZ3Ge5DOmGZ8aN67XqY/r6uNsHeammaYxv+WRMRCSTdLWlOcrmJmfijpS5qaA3BC0g/abKaYZvxRSWsj4ndt9jJdh74G8r61Efajkq6Y9vwLko610EdHEXGsuD8h6XFNfewYJsdPz6Bb3J9ouZ8/iIjjEfFxRHwiaaNafO+KacYflfSTiHisWNz6e9epr0G9b22E/XlJC2x/0fYFkr4laVsLfXyG7YuKL05k+yJJX9fwTUW9TdLtxePbJf2sxV4+ZVim8S6bZlwtv3etT38eEQO/SVqmqW/kD0n6pzZ6KOnrzyS9XNxea7s3SVs0dVo3qakzou9I+hNJOyW9XtyPDFFvmzU1tfcrmgrWaEu9LdXUR8NXJL1U3Ja1/d5V9DWQ943LZYEkuIIOSIKwA0kQdiAJwg4kQdiBJAg7kARhB5L4fxiBM+23LOsQAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Let us plot one image\n",
    "plt.imshow(images[10].numpy().squeeze(), cmap='Greys_r')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NN(\n",
       "  (fc1): Linear(in_features=784, out_features=192, bias=True)\n",
       "  (fc2): Linear(in_features=192, out_features=128, bias=True)\n",
       "  (fc3): Linear(in_features=128, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(784, 192)\n",
    "        self.fc2 = nn.Linear(192, 128)\n",
    "        self.fc3 = nn.Linear(128, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        ''' Forward pass through the network, returns the output logits '''        \n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        \n",
    "        return x\n",
    "    \n",
    "    def predict(self, x):\n",
    "        ''' To predict classes by calculating the softmax '''\n",
    "        logits = self.forward(x)\n",
    "        return F.softmax(logits, dim=1)\n",
    "\n",
    "model = NN()\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward pass and Calculate Cross Entropy Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to view the probability of classification of digit\n",
    "def view_classification(img, probs):\n",
    "    probs = probs.data.numpy().squeeze()\n",
    "    fig, (ax1, ax2) = plt.subplots(figsize=(6,7), ncols=2)\n",
    "    ax1.imshow(img.numpy().squeeze())\n",
    "    ax1.axis('off')\n",
    "    ax2.barh(np.arange(10), probs)\n",
    "    ax2.set_aspect(0.1)\n",
    "    ax2.set_yticks(np.arange(10))\n",
    "    ax2.set_yticklabels(np.arange(10).astype(int), size='large');\n",
    "    ax2.set_title('Probability')\n",
    "    ax2.set_xlim(0, 1.1)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADECAYAAAA8lvKIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAASN0lEQVR4nO3de7RcZXnH8e8vJykxCQlggCYBEhQCXtCIBwtYILbQSFILWGyjiNYqsVxs7VKW1IXITSnLVbWAt3CTq1QULyUBRKkXoKAnGhQswQgJhHALl8BJSIDD0z/2Pstxzp6TmZyZvWfv8/usNStn3v2++zyZzHrWm/fd+9mKCMzMLB9jig7AzGw0cdI1M8uRk66ZWY6cdM3McuSka2aWIyddM7McOemajTKSQtIeWzl2laRDGxw7SNKKrL6SPinpoq2LuFrGFh2AmTVH0ipgZ2AA2AAsBT4SEf1FxjUoIn4G7NXg2GcHf5Y0C3gAGBcRL+UTXffwTNesXN4REZOAfYH9gFNrD0ryRKrLOemalVBEPAzcALw+XS44UdLvgN8BSDpO0kpJT0n6vqTpdaeYL+l+SeskfU7SmHTcqyXdIunJ9NhVkrarG7ufpN9KelrSpZLGp2PnSlqTFa+k0yVdmb79afrnM5L6JR2SxrlPTf+dJD0vaceRfE7dyEnXrIQk7QrMB36VNh0J/BnwWkl/AZwD/B0wDVgNXFN3iqOAXpIZ8xHAPw6eOh07HXgNsCtwet3YY4B5wKuB2dTNtptwcPrndhExKSJ+ksb33po+7wZ+GBFPtHjurueka1Yu35X0DHAr8BNgcK30nIh4KiKeJ0mKl0TELyNiM/BvwAHpWuqgc9P+DwJfJElyRMTKiLg5IjanCe/zwCF1MVwQEQ9FxFPAZwbHjtBlwHsGZ9zAscAVbThv1/H6j1m5HBkRP6xtkATwUE3TdOCXg28iol/Sk8AMYFXaXNt/dToGSTsB5wEHAduSTMyeroshc+xIRMSdkjYAh0h6BNgD+P5Iz9uNPNM1q4bacoFrgZmDbyRNBF4JPFzTZ9ean3dLx0CytBDAGyJiMsl/+VX3uxqN3ZpYa12W/r5jgW9FxKYWz1sKTrpm1XM18AFJcyRtQ7IEcWdErKrpc7Kk7dO14X8B/itt3xboJ9nkmgGcnHH+EyXtImkH4JM1Y5v1BPAy8Kq69itI1prfC1ze4jlLw0nXrGIi4kfAp4BvA4+QbHgtrOv2PWAZsBxYAlyctp9Bsrm2Pm2/LuNXXA38ALg/fZ3dYnwbSdaCb5P0jKT90/Y1JMsiAfyslXOWiVzE3My6haRLgLUR0eoVEaXhjTQz6wrp1RXvBN5UcCgd5eUFMyucpLOAu4HPRcQDRcfTSV5eMDPL0bDLC4eNeZczsnXUzS9fW385klmleXnBzCxH3kizUWnq1Kkxa9asosOwilq2bNm6iMgs1uOka6PSrFmz6OvrKzoMqyhJqxsd8/KCmVmOnHTNzHLkpGtmliMnXTOzHHkjzUal3zy8nlmnLGlpzKp/X9ChaGw08UzXzCxHTrpmZjly0rXSk/Sa9Am269Mn4B5VdExmjTjpWqlJGktSkPt6YAdgEXClpNmFBmbWgJOuld3eJA9G/EJEDETELcBtJM/ZMus6TrpWdllVygS8fkijtEhSn6S+gY3rOx+ZWQYnXSu7e4HHSR60OE7SXwGHABPqO0bE4ojojYjenglT8o7TDHDStZKLiBeBI4EFwKPAx4BvAmsKDMusId8cYaUXEb8mmd0CIOl24LLiIjJrzDNdKz1Jb5A0XtIESR8HpgFfLzgss0xOulYFxwKPkKzt/iVwWERsLjYks2xeXrDSi4iTgZNbGbPPjCn0uZaCFcBJd5To2WuPIW33H5P5NBH+4z2XDmk7b4+92x6T2WjkpGuj0tZUGRvkamM2El7TNTPLkZOulZ6kWZKWSnpa0qOSLkhrMph1HSddq4Ivk1y5MA2Y
Q3LN7glFBmTWiGcDFXPfhftltj+w4MKmzzFv+pw2RZOb3YELImIT8KikG4HXFRyTWSbPdK0K/hNYmN4cMQM4HLix4JjMMjnpWhX8hGRm+yxJzYU+4Lv1nVxlzLqBk66VmqQxwE3AdcBEYCqwPXBufV9XGbNu4KRrZbcDsCvJmu7miHgSuBSYX2xYZtmcdK3UImId8ABwvKSxkrYD3g/cVWhgZg346oUSyLqFF2CPq1YPabtpevZVCvsvP3pI25T5K0cWWPd4J/BF4BPAAPA/wL8WGZBZI066VnoRsRyY28oYF7yxonh5wcwsR066ZmY58vKCjUquMmZFcdLtMlmbZicuuT6z74IJm4a07b7kuMy+s4/7xcgCM7O28PKClZqk/rrXgKTzi47LrBHPdK3UImLS4M+SJgKPAdcWF5HZ8DzTtSo5mqTE48+KDsSsESddq5L3A5dHRBQdiFkjTrpWCZJ2IyleftkwfVxlzArnNd2CNLq1N+tKhayrFADefMbxQ9pmf+1/RxZYeb0PuDUiHmjUISIWA4sBtpm2p2fDVgjPdK0q3scws1yzbuGka6Un6UBgBr5qwUrASdeq4P3AdRHxXNGBmG2J13St9CLiw62OcZUxK0rlk26jDaunvjC0rVP1ZUd6a2/WhhnA1NG7aWZWWl5eMDPLUeVnumZZtrbKmCuM2Uh5pmtmliMnXasESQsl/Z+kDZJ+L+mgomMyy+LlBSs9SYcB5wJ/D/wcmFZsRGaNVT7ptnKVwDzm5BZDo1t7975o6JUKM32VwpacAZwZEXek7x8uMhiz4Xh5wUpNUg/QC+woaaWkNZIukPSKjL4ueGOFc9K1stsZGEdSS/cgYA7wJuDU+o4RsTgieiOit2fClFyDNBvkpGtl93z65/kR8UhErAM+D8wvMCazhpx0rdQi4mlgDeBSjVYKldpIW/fhA4a0LZiwPLPv/suPHtI2hZHdBrxX37jM9pae2nuaN822wqXARyTdCLwIfBTI3kE1K1ilkq6NWmcBU4H7gE3AN4HPFBqRWQNOulZ6EfEicEL6aoqrjFlRvKZrZpYjJ10zsxx5ecFGpVarjLm6mLVLpZLusk9/pem+48/ffkS/a/3SoYXJz5v+rcy+SzaOH9I2+7hfjOj3m1k5eXnBSk/SjyVtktSfvlYUHZNZI066VhUnRcSk9LVX0cGYNeKka2aWIyddq4pzJK2TdJukuVkdXGXMukEpN9Ia3W7bih9ffOEIz7B8RKOzNuIaeezh7E2/bR4Z+s83sQ2VZEv4lOFPAL8FXgAWAv8taU5E/L62U0QsBhYDbDNtT9dqsEJ4pmulFxF3RsRzEbE5Ii4DbsNVxqxLOelaFQWgooMwy+Kka6UmaTtJ8ySNlzRW0jHAwcBNRcdmlqWUa7pmNcYBZwN7AwPAvcCREeFrda0rKaLxfsJhY95V+GbD6jOH1si990PN33lm7TFv+pyOnPfml68tZBmgt7c3+vr6ivjVNgpIWhYRvVnHvLxgZpYjJ10zsxx5TddGpVarjNVyxTEbCc90zcxy5KRrlSFpz7Ta2JVFx2LWSNcvL/hKBWvBlwAXKrau5pmuVYKkhcAzwI8KDsVsWE66VnqSJgNnAh/bQj9XGbPCOelaFZwFXBwRDw3XKSIWR0RvRPT2TJiSU2hmf6zr13TNhiNpDnAo8KaCQzFrStcn3flvO3pI22Nzdxzxec88+dIhbQsmbGp6/NwPHjekbZsbsvdwevYaWju3HX+HLOOOeKIj553Cyo6ctw3mArOAByUBTAJ6JL02IvYtMC6zTF2fdM22YDFwTc37j5Mk4eMLicZsC5x0rdQiYiOwcfC9pH5gU0R0ZspvNkJOulYpEXF6M/32mTGFPt/OawXw1QtmZjnyTNdGpa0peONCN9YOXZ90B1YM3TWfmtHWyObD98tsb+VKhb0vGronM/OG5p+YO9K/Q0u+1pnTmll7eHnBzCxHTrpWepKulPSIpGcl3SfpQ0XHZNaI
k65VwTnArIiYDPwNcLakNxcck1kmJ10rvYi4JyI2D75NX68uMCSzhrp+I22kfnzxhU33/ee12ZtuM09rftPMiiHpy8A/AK8AfgUszeizCFgE0DO5M7dhm22JZ7pWCRFxArAtcBBwHbA5o4+rjFnhnHStMiJiICJuBXbBtResSznpWhWNxWu61qWcdK3UJO0kaaGkSZJ6JM0D3g3cUnRsZlkqv5FmlRckSwlfJZlErAY+GhHfKzQqswYqlXTXLx1aLByWNz1+Re+LbYvF8pGWcDyk1XGuMmZF8fKCmVmOKjXTNWtWq1XGXGHM2sUzXTOzHDnpWqlJ2kbSxZJWS3pO0q8kHV50XGaNlHJ5IevpugB3zPlW0+fIrJGLb/ctobHAQySbaQ8C84FvStonIlYVGZhZllImXbNBEbEBOL2m6XpJDwBvBlYVEZPZcLy8YJUiaWdgNnBP0bGYZXHStcqQNA64CrgsIu7NOL5IUp+kvoGN6/MP0AwnXasISWOAK4AXgJOy+rjKmHUDr+la6UkScDGwMzA/InxroXWtUibdPa5a3XTfJRvHZ7a7MHmlfAV4DXBoRDxfdDBmw/HygpWapJnAh4E5wKOS+tPXMcVGZpatlDNds0ERsRpQq+Nc8MaK4pmumVmOnHTNzHLU9csLWbf8nje9+dt9v7TgrxscWbmVEVkVtFJlzBXGrJ080zUzy5GTrpWepJPSO802S/p60fGYDafrlxfMmrAWOBuYB7yi4FjMhuWka6UXEdcBSOoFdik4HLNheXnBzCxHXT/T3fiq7Zvum3XL78AKX6VgCUmLgEUAPZN3LDgaG60807VRw1XGrBs46ZqZ5ajrlxfMtkTSWJLvcg/QI2k88FJEvFRsZGZDeaZrVXAq8DxwCvDe9OdTC43IrAFFRMODh415V+ODZm1w88vXtlwhrB16e3ujr6+viF9to4CkZRHRm3XMM10zsxw56ZqZ5cgbaTYqtVJlLIsrj9nW8kzXzCxHTrpWepJ2kPQdSRskrZb0nqJjMmvEywtWBV8CXiB5BPscYImkuyLinkKjMsvgma6VmqSJwN8Cn4qI/oi4Ffg+cGyxkZllc9K1spsNDETEfTVtdwGvq+8oaVFa7LxvYOP63AI0q+Wka2U3CajPoOuBbes7uuCNdQMnXSu7fmByXdtk4LkCYjHbIiddK7v7gLGS9qxpeyPgTTTrSk66VmoRsQG4DjhT0kRJbwWOAK4oNjKzbE66VgUnkDyQ8nHgG8DxvlzMupWv07XSi4ingCNbGbPPjCn0+VZeK4BnumZmOXLSNTPLkZOumVmOnHTNzHLkpGtmliMnXTOzHPmSMRuVli1b1i9pRdFxAFOBdUUHkXIsQ21tHDMbHRj2acBmVSWpr9HTWkdjHOBY8orDywtmZjly0jUzy5GTro1Wi4sOINUtcYBjydL2OLyma2aWI890zcxy5KRrlSLp7ZJWSFop6ZSM45J0Xnr815L2bXZsB2I5Jo3h15Jul/TGmmOrJP1G0nJJfR2OY66k9envWi7ptGbHdiCWk2viuFvSgKQd0mPt/EwukfS4pLsbHO/c9yQi/PKrEi+gB/g98CrgT0geUPnauj7zgRsAAfsDdzY7tgOxHAhsn/58+GAs6ftVwNScPpO5wPVbM7bdsdT1fwdwS7s/k/RcBwP7Anc3ON6x74lnulYlbwFWRsT9EfECcA3JUyRqHQFcHok7gO0kTWtybFtjiYjbI+Lp9O0dwC4j+H1bHUeHxrbjfO8mKUrfdhHxU+CpYbp07HvipGtVMgN4qOb9mrStmT7NjG13LLU+SDKzGhTADyQtk7QohzgOkHSXpBskDT6+vrDPRNIE4O3At2ua2/WZNKNj3xPfBmxVooy2+stzGvVpZmy7Y0k6Sm8jSbp/XtP81ohYK2kn4GZJ96azs07E8UtgZkT0S5oPfBfYs8mx7Y5l0DuA2yJ5Ksigdn0mzejY98QzXauSNcCuNe93AdY22aeZse2OBUlv
AC4CjoiIJwfbI2Jt+ufjwHdI/lvbkTgi4tmI6E9/XgqMkzS12b9DO2OpsZC6pYU2fibN6Nz3pB2L0n751Q0vkv+53Q/szh82OV5X12cBf7xB8vNmx3Yglt2AlcCBde0TgW1rfr4deHsH4/hT/nDN/luAB9PPJ/fPJO03hWS9dWInPpOac86i8UZax74nXl6wyoiIlySdBNxEsst8SUTcI+mf0uNfBZaS7EyvBDYCHxhubIdjOQ14JfBlSQAvRVJcZWfgO2nbWODqiLixg3EcDRwv6SXgeWBhJBmmiM8E4CjgBxGxoWZ42z4TAEnfILlqY6qkNcCngXE1cXTse+I70szMcuQ1XTOzHDnpmpnlyEnXzCxHTrpmZjly0jUzy5GTrplZjpx0zcxy5KRrZpaj/wcxwWV6qx+57gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x504 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataiter = iter(trainloader)\n",
    "images, labels = next(dataiter)\n",
    "# Flatten each 1x28x28 image into a 784-element vector\n",
    "images = images.view(images.size(0), -1)\n",
    "\n",
    "# Forward pass through the network\n",
    "img_idx = 0\n",
    "logits = model(images)\n",
    "\n",
    "# Convert the logits to class probabilities\n",
    "prediction = F.softmax(logits, dim=1)\n",
    "\n",
    "img = images[img_idx].data\n",
    "view_classification(img.reshape(1, 28, 28), prediction[img_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset MNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: MNIST_data/\n",
       "    Split: Train\n",
       "    StandardTransform\n",
       "Transform: Compose(\n",
       "               ToTensor()\n",
       "               Normalize(mean=0.5, std=0.5)\n",
       "           )"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainloader.dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Back Propagation\n",
    "We now need to train the network. For each batch, we first calculate the cross-entropy loss and then back-propagate the error to adjust the weights."
   ]
  },
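  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make these two steps concrete, here is a small sketch on a made-up example. The cross-entropy loss is computed from hypothetical logits for a single sample, and `backward()` then fills in the gradient of the loss with respect to those logits. The tensor values are illustrative only and are unrelated to the MNIST model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Hypothetical logits for one sample over three classes (illustrative values only)\n",
    "logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)\n",
    "target = torch.tensor([0])  # index of the true class\n",
    "\n",
    "# Cross-entropy applies log-softmax to the logits, then the negative log-likelihood\n",
    "loss = F.cross_entropy(logits, target)\n",
    "\n",
    "# Back propagation: compute the gradient of the loss with respect to the logits\n",
    "loss.backward()\n",
    "\n",
    "# For a single sample, this gradient equals softmax(logits) minus the one-hot target\n",
    "print(loss.item())\n",
    "print(logits.grad)"
   ]
  },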
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create an optimizer to train the network by carrying out back propagation\n",
    "model = NN()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "loss_fn = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1/1 Loss: 2.0406 Test accuracy: 0.5407\n",
      "Epoch: 1/1 Loss: 1.3468 Test accuracy: 0.6823\n",
      "Epoch: 1/1 Loss: 0.8910 Test accuracy: 0.7483\n",
      "Epoch: 1/1 Loss: 0.7174 Test accuracy: 0.8199\n",
      "Epoch: 1/1 Loss: 0.5849 Test accuracy: 0.8447\n",
      "Epoch: 1/1 Loss: 0.4994 Test accuracy: 0.8663\n",
      "Epoch: 1/1 Loss: 0.4296 Test accuracy: 0.8682\n",
      "Epoch: 1/1 Loss: 0.4548 Test accuracy: 0.8636\n",
      "Epoch: 1/1 Loss: 0.4076 Test accuracy: 0.8898\n",
      "Epoch: 1/1 Loss: 0.3213 Test accuracy: 0.8802\n",
      "Epoch: 1/1 Loss: 0.4445 Test accuracy: 0.8952\n",
      "Epoch: 1/1 Loss: 0.3564 Test accuracy: 0.8916\n",
      "Epoch: 1/1 Loss: 0.4196 Test accuracy: 0.9016\n",
      "Epoch: 1/1 Loss: 0.3701 Test accuracy: 0.8919\n",
      "Epoch: 1/1 Loss: 0.3376 Test accuracy: 0.8896\n",
      "Epoch: 1/1 Loss: 0.3350 Test accuracy: 0.8972\n",
      "Epoch: 1/1 Loss: 0.3595 Test accuracy: 0.9022\n",
      "Epoch: 1/1 Loss: 0.3318 Test accuracy: 0.9055\n",
      "Epoch: 1/1 Loss: 0.3191 Test accuracy: 0.8992\n",
      "Epoch: 1/1 Loss: 0.3498 Test accuracy: 0.9057\n",
      "Epoch: 1/1 Loss: 0.3106 Test accuracy: 0.9134\n",
      "Epoch: 1/1 Loss: 0.3212 Test accuracy: 0.9159\n",
      "Epoch: 1/1 Loss: 0.3087 Test accuracy: 0.9036\n",
      "Epoch: 1/1 Loss: 0.3591 Test accuracy: 0.9068\n",
      "Epoch: 1/1 Loss: 0.2975 Test accuracy: 0.9124\n",
      "Epoch: 1/1 Loss: 0.3008 Test accuracy: 0.9159\n",
      "Epoch: 1/1 Loss: 0.3148 Test accuracy: 0.9064\n",
      "Epoch: 1/1 Loss: 0.3477 Test accuracy: 0.9170\n",
      "Epoch: 1/1 Loss: 0.2940 Test accuracy: 0.9152\n",
      "Epoch: 1/1 Loss: 0.3043 Test accuracy: 0.9205\n",
      "Epoch: 1/1 Loss: 0.2912 Test accuracy: 0.9163\n",
      "Epoch: 1/1 Loss: 0.2933 Test accuracy: 0.9229\n",
      "Epoch: 1/1 Loss: 0.2661 Test accuracy: 0.9192\n",
      "Epoch: 1/1 Loss: 0.2395 Test accuracy: 0.9259\n",
      "Epoch: 1/1 Loss: 0.2635 Test accuracy: 0.9183\n",
      "Epoch: 1/1 Loss: 0.3114 Test accuracy: 0.9203\n",
      "Epoch: 1/1 Loss: 0.2506 Test accuracy: 0.9314\n",
      "Epoch: 1/1 Loss: 0.2874 Test accuracy: 0.9228\n",
      "Epoch: 1/1 Loss: 0.2506 Test accuracy: 0.9277\n",
      "Epoch: 1/1 Loss: 0.2428 Test accuracy: 0.9329\n",
      "Epoch: 1/1 Loss: 0.2528 Test accuracy: 0.9267\n",
      "Epoch: 1/1 Loss: 0.2909 Test accuracy: 0.9219\n",
      "Epoch: 1/1 Loss: 0.2477 Test accuracy: 0.9283\n",
      "Epoch: 1/1 Loss: 0.2585 Test accuracy: 0.9335\n",
      "Epoch: 1/1 Loss: 0.2373 Test accuracy: 0.9287\n",
      "Epoch: 1/1 Loss: 0.2348 Test accuracy: 0.9276\n"
     ]
    }
   ],
   "source": [
    "# Train network\n",
    "\n",
    "epochs = 1\n",
    "steps = 0\n",
    "running_loss = 0\n",
    "eval_freq = 10\n",
    "for e in range(epochs):\n",
    "    for images, labels in trainloader:\n",
    "        steps += 1\n",
    "        # Flatten each 1x28x28 image into a 784-element vector\n",
    "        images = images.view(images.size(0), -1)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        output = model(images)\n",
    "        loss = loss_fn(output, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        if steps % eval_freq == 0:\n",
    "            # Test accuracy, averaged over test batches (no gradients needed)\n",
    "            accuracy = 0\n",
    "            with torch.no_grad():\n",
    "                for ii, (test_images, test_labels) in enumerate(testloader):\n",
    "                    test_images = test_images.view(test_images.size(0), -1)\n",
    "                    predicted = model.predict(test_images)\n",
    "                    equality = (test_labels == predicted.max(1)[1])\n",
    "                    accuracy += equality.float().mean().item()\n",
    "\n",
    "            print(\"Epoch: {}/{}\".format(e+1, epochs),\n",
    "                  \"Loss: {:.4f}\".format(running_loss/eval_freq),\n",
    "                  \"Test accuracy: {:.4f}\".format(accuracy/(ii+1)))\n",
    "            running_loss = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADECAYAAAA8lvKIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAR8ElEQVR4nO3de7QdZX3G8e9DkhKSQAATKAkkESEgCAY8WNQisYVGklrQYhtuWqvEcrG1S1lSFyICSl2uquXiJdxEQC0oICXhphTkUtATDQiWYIAEQkASLpGTkADh1z9mDmz2nn2yk7P3zJ45z2etvXL2O+/MebKz12+9eWfmHUUEZmaWj82KDmBmNpS46JqZ5chF18wsRy66ZmY5ctE1M8uRi66ZWY5cdM2GGEkhaZdN3HeJpIOabDtA0qKsvpI+L+mCTUtcLcOLDmBmrZG0BNgeWA+sBuYDn4qIviJz9YuI24Hdmmz7Sv/PkqYAjwIjIuKVfNJ1D490zcrlAxExBtgX2A84pXajJA+kupyLrlkJRcQTwPXA29LpghMk/R74PYCkYyUtlvSspGslTag7xExJj0haKelrkjZL93uLpFskPZNuu1zS1nX77ifpd5Kek3SxpJHpvtMlLcvKK+k0SZelb3+R/vm8pD5JB6Y596rpv52kFyWNH8zn1I1cdM1KSNJOwEzgN2nTYcCfAXtI+gvgLODvgB2ApcCP6g7xQaCHZMR8KPCP/YdO950AvBXYCTitbt+jgBnAW4Cp1I22W/De9M+tI2JMRNyW5ju6ps8RwM8iYsVGHrvrueialcs1kp4H7gBuA/rnSs+KiGcj4kWSonhRRPw6ItYB/wa8K51L7ffVtP9jwDdJihwRsTgibo6IdWnB+zpwYF2GcyPi8Yh4Fvhy/76DdAlwZP+IGzgGuLQNx+06nv8xK5fDIuJntQ2SAB6vaZoA/Lr/TUT0SXoGmAgsSZtr+y9N90HSdsDZwAHAliQDs+fqMmTuOxgRcY+k1cCBkp4EdgGuHexxu5FHumbVULtc4HJgcv8bSaOBNwFP1PTZqebnSek+kEwtBLB3RGxF8l9+1f2uZvtuStZal6S/7xjgxxGxdiOPWwouumbV8wPgY5KmSdqcZArinohYUtPnJEnbpHPD/wL8V9q+JdBHcpJrInBSxvFPkLSjpG2Bz9fs26oVwKvAznXtl5LMNR8NfH8jj1kaLrpmFRMRPwe+APwEeJLkhNfsum4/BRYAC4F5wIVp+5dITq6tStuvyvgVPwBuAh5JX2duZL41JHPBd0p6XtL+afsykmmRAG7fmGOWibyIuZl1C0kXAcsjYmOviCgNn0gzs66QXl3xIWCfgqN0lKcXzKxwks4A7ge+FhGPFp2nkzy9YGaWowGnFw7e7MOuyNZRN796Zf3lSGaV5ukFM7Mc+USaDUnjxo2LKVOmFB3DKmrBggUrIyJzsR4XXRuSpkyZQm9vb9ExrKIkLW22zdMLZmY5ctE1M8uRi66ZWY5cdM3McuSia2aWIxddM7McueiameXIRddKT9Jb0yfYrkqfgPvBojOZNeOia6UmaTjJgtzXAdsCc4DLJE0tNJhZEy66Vna7kzwY8RsRsT4ibgHuJHnOllnXcdG1sstapUzA2xoapTmSeiX1rlixovPJzDK46FrZPQg8TfKgxRGS/go4EBhV3zEi5kZET0T0jB+fuRaJWce56FqpRcTLwGHALOAp4DPAFcCyAmOZNeVVxqz0IuI+ktEtAJLuAi4pLpFZcx7pWulJ2lvSSEmjJH0W2AH4XsGxzDK56FoVHAM8STK3+5fAwRGxrthIZtk8vWClFxEnAScVncOsFS66Q8Sw3XZpaHvkqOwz+P9x5MUNbWfvsnvbM5kNRZ5eMDPLkYuumVmOXHSt9CRNkTRf0nOSnpJ0bromg1nXcdG1KvgWyZULOwDTSK7ZPb7IQGbNeDRQMQ+dv19m+6Ozzm/5GDMmTGtTmty8
GTg3ItYCT0m6Adiz4ExmmTzStSr4T2B2enPEROAQ4IaCM5llctG1KriNZGT7R5I1F3qBa+o7eZUx6wYuulZqkjYDbgSuAkYD44BtgK/W9/UqY9YNXHSt7LYFdiKZ010XEc8AFwMzi41lls1F10otIlYCjwLHSRouaWvgo8C9hQYza8JXL5RA1i28ALtcvrSh7cYJ2Vcp7L/w8Ia2sTMXDy5Y9/gQ8E3gc8B64H+Afy0ykFkzLrpWehGxEJhecAyzlnh6wcwsRy66ZmY5ctE1M8uR53S7TNZJsxPmXZfZd9aotQ1tb553bGbfqcf+anDBzKwtPNK1UpPUV/daL+mconOZNeORrpVaRIzp/1nSaOAPwJXFJTIbmEe6ViWHkyzxeHvRQcyacdG1Kvko8P2IiKKDmDXjomuVIGkSyeLllwzQx6uMWeE8p1uQZrf2Zl2pkHWVAsA7vnRcQ9vU7/7v4IKV10eAOyLi0WYdImIuMBegp6fHo2ErhEe6VhUfYYBRrlm3cNG10pP0bmAivmrBSsBF16rgo8BVEfFC0UHMNsRzulZ6EfHJojOYtaryRbfZCatnv9HY1qn1ZQd7a2/WCTOAcUP3pJlZaXl6wcwsR5Uf6Zpl+e0Tq5hy8rzX3i/591kFprGhxCNdM7McuehaJUiaLen/JK2W9LCkA4rOZJbF0wtWepIOBr4K/D3wS2CHYhOZNVf5orsxVwnMYFpuGZrd2rv7BY1XKkz2VQob8iXg9Ii4O33/RJFhzAbi6QUrNUnDgB5gvKTFkpZJOlfSFhl9X1vwZv2aVfmHNcNF18pve2AEyVq6BwDTgH2AU+o7RsTciOiJiJ5ho8bmGtKsn4uuld2L6Z/nRMSTEbES+Dows8BMZk256FqpRcRzwDLASzVaKVTqRNrKT76roW3WqIWZffdfeHhD21gGdxvwbr0jMts36qm9p/qk2Sa4GPiUpBuAl4FPA9lnUM0KVqmia0PWGcA44CFgLXAF8OVCE5k14aJrpRcRLwPHp6+W7DVxLL2+9dcK4DldM7McueiameXIRdeGpPpVxszyUqk53QVf/HbLfUees82gfteq+Y0Lk5894ceZfeetGdnQNvXYXw3q95tZOXmka6Un6VZJayX1pa9FRWcya8ZF16rixIgYk752KzqMWTMuumZmOXLRtao4S9JKSXdKmp7VwauMWTco5Ym0ZrfbboxbLzx/kEdYOKi9s07ENfOHJ7JP+m3+ZOM/3+g2rCRbwqcMfw74HfASMBv4b0nTIuLh2k4RMReYC7D5Drt6rQYrhEe6VnoRcU9EvBAR6yLiEuBOvMqYdSkXXauiAFR0CLMsLrpWapK2ljRD0khJwyUdBbwXuLHobGZZSjmna1ZjBHAmsDuwHngQOCwifK2udSVFND+fcPBmHy78ZMPS0xvXyH3wE63feWbtMWPCtI4c9+ZXryxkGqCnpyd6e3uL+NU2BEhaEBE9Wds8vWBmliMXXTOzHHlO14YkrzJmAEsKWMjeI10zsxy56FplSNo1XW3ssqKzmDXT9dMLvlLBNsJ5gBcqtq7mka5VgqTZwPPAzwuOYjYgF10rPUlbAacDn9lAP68yZoVz0bUqOAO4MCIeH6hTRMyNiJ6I6Bk2amxO0czeqOvndM0GImkacBCwT8FRzFrS9UV35vsOb2j7w/Txgz7u6Sdd3NA2a9Talvef/vFjG9o2vz77HM6w3RrXzm3H3yHLiENXdOS4Y1nckeO2wXRgCvCYJIAxwDBJe0TEvgXmMsvU9UXXbAPmAj+qef9ZkiJ8XCFpzDbARddKLSLWAGv630vqA9ZGRGeG/GaD5KJrlRIRp7XSb6+JY+kt4BZQM1+9YGaWIxddM7Mcdf30wvpFjWfNx2W0NbPukP0y2zfmSoXdL2g8JzP5+tafmDvYv8NG+W5nDmtm7eGRrplZjlx0rfQkXSbpSUl/lPSQpE8UncmsGRddq4KzgCkRsRXwN8CZkt5RcCazTC66VnoR8UBE
rOt/m77eUmAks6a6/kTaYN164fkt9/3n5dkn3Saf2vpJMyuGpG8B/wBsAfwGmJ/RZw4wB2DSpEl5xjN7jUe6VgkRcTywJXAAcBWwLqPPa6uMjR/fmbUvzDbERdcqIyLWR8QdwI547QXrUi66VkXD8ZyudSkXXSs1SdtJmi1pjKRhkmYARwC3FJ3NLEvlT6RZ5QXJVMJ3SAYRS4FPR8RPC01l1kSliu6q+Y2LhcPClvdf1PNy27JYPtIlHA8sOodZqzy9YGaWIxddM7McueiameXIRddKTdLmki6UtFTSC5J+I+mQonOZNVPKE2lZT9cFuHvaj1s+RuYaufh23xIaDjxOcjLtMWAmcIWkvSJiSZHBzLKUsuia9YuI1cBpNU3XSXoUeAewpIhMZgPx9IJViqTtganAA0VnMcviomuVIWkEcDlwSUQ8mLF9jqReSb0rVvgJ7VYMF12rBEmbAZcCLwEnZvXxKmPWDTyna6UnScCFwPbAzIjwrYXWtUpZdHe5fGnLfeetGZnZ7oXJK+XbwFuBgyLixaLDmA3E0wtWapImA58EpgFPSepLX0cVm8wsWylHumb9ImIpoKJzmLXKI10zsxy56JqZ5ajrpxeybvk9e0Lrt/ueN+uvm2xZvImJzMw2nUe6ZmY5ctG10pN0Ynqn2TpJ3ys6j9lAun56wawFy4EzgRnAFgVnMRuQi66VXkRcBSCpB9ix4DhmA/L0gplZjrp+pLtm521a7pt1y+/6Rb5KwRKS5gBzACZNmlRwGhuqPNK1IcOrjFk3cNE1M8tR108vmG2IpOEk3+VhwDBJI4FXIuKVYpOZNfJI16rgFOBF4GTg6PTnUwpNZNZE1490N7/+Vw1tMyZMyz+Ida2IOI03PpzSrGt5pGtmliMXXTOzHLnompnlyEXXzCxHLrpWepK2lXS1pNWSlko6suhMZs10/dULZi04D3iJ5BHs04B5ku6NiAcKTWWWwSNdKzVJo4G/Bb4QEX0RcQdwLXBMscnMsrnoWtlNBdZHxEM1bfcCe9Z3lDQnXey8d8WKFbkFNKvlomtlNwZYVde2CtiyvqMXvLFu4KJrZdcHbFXXthXwQgFZzDbIRdfK7iFguKRda9reDvgkmnUlF10rtYhYDVwFnC5ptKT3AIcClxabzCybi65VwfEkD6R8GvghcJwvF7Nu5et0rfQi4lngsKJzmLXCI10zsxy56JqZ5chF18wsRy66ZmY5ctE1M8uRi66ZWY58yZgNSQsWLOiTtKjoHMA4YGXRIVLO0mhTc0xutkERselxzEpKUm9E9DjH65wlnxyeXjAzy5GLrplZjlx0baiaW3SAVLfkAGfJ0vYcntM1M8uRR7pmZjly0bVKkfR+SYskLZZ0csZ2STo73X6fpH1b3bcDWY5KM9wn6S5Jb6/ZtkTSbyUtlNTb4RzTJa1Kf9dCSae2um8HspxUk+N+SeslbZtua+dncpGkpyXd32R7574nEeGXX5V4AcOAh4GdgT8heUDlHnV9ZgLXAwL2B+5pdd8OZHk3sE368yH9WdL3S4BxOX0m04HrNmXfdmep6/8B4JZ2fybpsd4L7Avc32R7x74nHulalbwTWBwRj0TES8CPSJ4iUetQ4PuRuBvYWtIOLe7b1iwRcVdEPJe+vRvYcRC/b5NzdGjfdhzvCJJF6dsuIn4BPDtAl459T1x0rUomAo/XvF+WtrXSp5V9252l1sdJRlb9ArhJ0gJJc3LI8S5J90q6XlL/4+sL+0wkjQLeD/ykprldn0krOvY98W3AViXKaKu/PKdZn1b2bXeWpKP0PpKi++c1ze+JiOWStgNulvRgOjrrRI5fA5Mjok/STOAaYNcW9213ln4fAO6M5Kkg/dr1mbSiY98Tj3StSpYBO9W83xFY3mKfVvZtdxYk7Q1cABwaEc/0t0fE8vTPp4GrSf5b25EcEfHHiOhLf54PjJA0rtW/Qzuz1JhN3dRCGz+TVnTue9KOSWm//OqGF8n/3B4B3szr
Jzn2rOszizeeIPllq/t2IMskYDHw7rr20cCWNT/fBby/gzn+lNev2X8n8Fj6+eT+maT9xpLMt47uxGdSc8wpND+R1rHviacXrDIi4hVJJwI3kpxlvigiHpD0T+n27wDzSc5MLwbWAB8baN8OZzkVeBPwLUkAr0SyuMr2wNVp23DgBxFxQwdzHA4cJ+kV4EVgdiQVpojPBOCDwE0Rsbpm97Z9JgCSfkhy1cY4ScuALwIjanJ07HviO9LMzHLkOV0zsxy56JqZ5chF18wsRy66ZmY5ctE1M8uRi66ZWY5cdM3McuSia2aWo/8HwTwoVfUuesQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x504 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "logits = model.forward(img[None,])\n",
    "\n",
    "# Predict the class from the network output\n",
    "prediction = F.softmax(logits, dim=1)\n",
    "\n",
    "view_classification(img.reshape(1, 28, 28), prediction[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the model correctly predicts the digit after training, whereas before training it assigned almost equal probability to every digit, i.e. it was effectively guessing at random."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
